import os
import glob
import torch
import wandb
import imageio
import pathlib
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torch.nn.functional as F
from scipy.stats import gaussian_kde
from vis.kde import KernelDensityEstimator as KDE
from datatool.dataset.toy_t import get_true_log_prob


def creat_toy_plt_gif(fpath):
    """Assemble the ``epoch_<N>.png`` frames found in *fpath* into a GIF.

    Frames are ordered by the integer ``<N>`` in their file name, written to
    ``fpath + 'v.gif'`` (i.e. next to the directory, not inside it), optionally
    optimized in place with the external ``gifsicle`` tool, and logged to
    wandb under ``video/gif``.

    Args:
        fpath: Directory containing the ``epoch_<N>.png`` frames.

    Returns:
        None. Returns early, without side effects, when no frame is found.

    Raises:
        ValueError: If a matching file name has a non-integer ``<N>`` part.
    """
    import subprocess

    filenames = glob.glob(fpath + '/epoch_*.png')
    if not filenames:
        return

    prefix, suffix = 'epoch_', '.png'

    def _epoch_of(name):
        # Parse from the base name only, so the directory part (which may
        # contain '.', 'epoch', ...) can never corrupt the extracted number.
        return int(os.path.basename(name)[len(prefix):-len(suffix)])

    filenames.sort(key=_epoch_of)
    images = [imageio.imread(filename) for filename in filenames]
    # use your own path
    gif_path = fpath+'v.gif'
    imageio.mimsave(gif_path, images, duration=0.2)
    # Argument list instead of a shell string: paths with spaces or shell
    # metacharacters are passed through verbatim.
    try:
        subprocess.run(
            ['gifsicle', '--scale', '1.', '-O3', gif_path, '-o', gif_path],
            check=False,
        )
    except OSError:
        # Optimization is best-effort: without gifsicle the unoptimized GIF
        # written above is still logged.
        pass
    wandb.log({"video/gif": wandb.Video(gif_path, fps=4, format="gif")})
    print('Done.')


class ToyPlot:
    """Draw marginal / joint / conditional probability and energy figures.

    Quantities are evaluated on a square ``n_pts x n_pts`` grid covering
    ``[-range_lim, range_lim]`` on both axes. The values come either from a
    model (``model.jc``) or from a reference log-probability: an analytic one
    when the dataset name contains ``'gaussians'``, a per-class kernel density
    estimate otherwise.
    """

    def __init__(self, dataset_name, path, n_cls, dataset):
        """Create the output directory, the evaluation grid and data handles.

        Args:
            dataset_name: Name of the dataset; selects the reference
                log-probability (analytic if it contains ``'gaussians'``).
            path: Base directory; figures are saved under ``<path>/images``.
            n_cls: Number of classes (one figure row per class, plus one).
            dataset: Object exposing ``tensors``; ``tensors[0]`` is used as the
                inputs and ``tensors[1]`` as the targets.
        """
        self.dataset_name = dataset_name
        self.spath = os.path.join(path, 'images')
        pathlib.Path(self.spath).mkdir(parents=True, exist_ok=True)
        self.n_cls = n_cls
        self.n_pts = 1001
        self.mid = (self.n_pts - 1) // 2  # index of the grid's centre line
        self.range_lim = 8
        self.test_grid = self._setup_grid()
        self.X = dataset.tensors[0]
        self.y = dataset.tensors[1]
        # detach/cpu so tensors living on GPU or tracking grad also convert.
        self.data = self.X.detach().cpu().numpy()
        self.is_true = 'gaussians' in self.dataset_name
        self.log_prob = self._log_prob_true if self.is_true else self._log_prob_kde

    @torch.no_grad()
    def _setup_grid(self):
        """Build the evaluation grid.

        Returns:
            Tuple ``(xx, yy, zz)``: ``xx`` and ``yy`` are the
            ``(n_pts, n_pts)`` coordinate matrices (``'ij'`` indexing) and
            ``zz`` the flattened ``(n_pts * n_pts, 2)`` point list, placed on
            the GPU when CUDA is available and left on the CPU otherwise.
        """
        self.linsp = torch.linspace(-self.range_lim, self.range_lim, self.n_pts)
        try:
            # Explicit indexing keeps the layout stable across torch versions.
            xx, yy = torch.meshgrid(self.linsp, self.linsp, indexing='ij')
        except TypeError:
            # Older torch has no `indexing` argument; its only mode is 'ij'.
            xx, yy = torch.meshgrid(self.linsp, self.linsp)
        zz = torch.stack((xx.flatten(), yy.flatten()), dim=1)
        if torch.cuda.is_available():
            zz = zz.cuda()
        return xx, yy, zz

    def _points_by_class(self):
        """Group ``self.X`` into one tensor per class index.

        NOTE(review): assumes ``self.y`` holds one integer class label per
        sample — confirm against the dataset definition.
        """
        labels = self.y.reshape(-1)
        return [self.X[labels == c] for c in range(self.n_cls)]

    @torch.no_grad()
    def _log_prob_kde(self, zz):
        """Reference log-probabilities from one KDE per class.

        Args:
            zz: ``(N, 2)`` tensor of query points.

        Returns:
            Tuple ``(log_z_kde_j, log_z_kde_c)``, both ``(N, n_cls)``: the
            per-class KDE outputs and their ``log_softmax`` over classes.
        """
        # Honour an externally assigned ``XbyY``; otherwise build it once.
        x_by_y = getattr(self, 'XbyY', None)
        if x_by_y is None:
            x_by_y = self.XbyY = self._points_by_class()
        kdes = [KDE(x_by_y[c].to(zz.device)) for c in range(self.n_cls)]
        # Evaluate in chunks of n_pts rows to bound memory use.
        # NOTE(review): assumes kde(zp) yields one log-density per query
        # point — confirm against vis.kde.KernelDensityEstimator.
        per_chunk = [
            torch.stack([kde(zp).reshape(-1) for kde in kdes], dim=0)
            for zp in torch.split(zz, self.n_pts)
        ]
        # (n_cls, N) -> (N, n_cls), same layout as _log_prob_true.
        log_z_kde_j = torch.cat(per_chunk, dim=1).transpose(0, 1)
        log_z_kde_c = F.log_softmax(log_z_kde_j, dim=-1)
        return log_z_kde_j, log_z_kde_c

    @torch.no_grad()
    def _log_prob_true(self, zz):
        """Reference log-probabilities from ``get_true_log_prob``.

        Args:
            zz: ``(N, 2)`` tensor of query points.

        Returns:
            Tuple ``(logits_j, logits_c)``: the chunk results concatenated
            along dim 1 and transposed, and their ``log_softmax`` over the
            last dimension.
        """
        chunks = [get_true_log_prob(self.dataset_name, zp)
                  for zp in torch.split(zz, self.n_pts)]
        logits_j = torch.cat(chunks, dim=1).transpose(0, 1)
        logits_c = F.log_softmax(logits_j, dim=-1)
        return logits_j, logits_c

    @torch.no_grad()
    def plot(self, model, epoch, it, x_q, x_r, f='model'):
        """Render one diagnostic figure, log it to wandb and save it to disk.

        The figure has ``n_cls + 1`` rows and 8 columns. Columns 0-3 hold 1-D
        slices through the grid centre, columns 4-7 the matching 2-D maps.
        Row 0 shows the marginal probability/energy (plus optional sample
        scatters in columns 1 and 3); row ``1 + j`` shows joint and
        conditional probability/energy of class ``j``.

        Args:
            model: Object whose ``jc(zz)`` returns ``(logits_j, logits_c)``;
                only used when ``f == 'model'``.
            epoch: Epoch number, shown in the figure title.
            it: Iteration number, shown in the figure title.
            x_q: Optional tensor of 2-D samples scattered as 'q samples'
                (model mode only); ``None`` to skip.
            x_r: Optional tensor of 2-D samples scattered as 'r samples'
                (model mode only); ``None`` to skip.
            f: ``'model'`` to plot the model; any other value plots the
                reference log-probability, and ``'true'`` additionally
                scatters the dataset points.
        """
        _, _, zz = self.test_grid
        logits_j, logits_c = model.jc(zz) if f == 'model' else self.log_prob(zz)
        log_marginal = torch.logsumexp(logits_j, -1)
        me = -log_marginal
        # softmax(x) == exp(x) / exp(x).sum(), but does not overflow to
        # inf/inf = nan when the logits are large.
        mp = F.softmax(log_marginal, dim=0)
        jp = F.softmax(logits_j.reshape(-1), dim=0).reshape(logits_j.shape)  # jp = mp * cp
        je = -logits_j
        cp = F.softmax(logits_c, dim=-1)
        ce = -logits_c
        mp, me, jp, je, cp, ce = [v.cpu() for v in (mp, me, jp, je, cp, ce)]

        n_rows = self.n_cls + 1
        # squeeze=False keeps `axs` 2-D even when there is a single row.
        fig, axs = plt.subplots(n_rows, 8, figsize=(8*3.5, n_rows*3), squeeze=False)
        try:
            self._plot_1n2(fig, axs, 0, 0, mp, 'Marginal Prob.')
            self._plot_1n2(fig, axs, 2, 0, me, 'Marginal Energy')
            if f == 'model' and x_q is not None:
                self._plot_scatter_2d(axs, 1, 0, x_q.detach().cpu().numpy(), 'q samples')
            elif f == 'true':
                self._plot_scatter_2d(axs, 1, 0, self.data, 'Data')
            if f == 'model' and x_r is not None:
                self._plot_scatter_2d(axs, 3, 0, x_r.detach().cpu().numpy(), 'r samples')
            for j in range(self.n_cls):
                self._plot_1n2(fig, axs, 0, 1+j, jp[:, j], 'Joint Prob. of {}'.format(j))
                self._plot_1n2(fig, axs, 2, 1+j, je[:, j], 'Joint Energy of {}'.format(j))
                self._plot_1n2(fig, axs, 1, 1+j, cp[:, j], 'Cond. Prob. of {}'.format(j))
                self._plot_1n2(fig, axs, 3, 1+j, ce[:, j], 'Cond. Energy of {}'.format(j))
            fig.suptitle('epoch_{}_{}'.format(epoch, it))
            fig.tight_layout(rect=[0, 0.03, 1, 0.95])
            # Fixed file names: each call overwrites the previous figure.
            fname = 'epoch.png' if f == 'model' else 'true.png'
            wandb.log({'energy/{}'.format(fname): wandb.Image(fig)})
            fig.savefig(os.path.join(self.spath, fname))
        finally:
            # Always release the figure, even if plotting or logging failed.
            plt.close(fig)

    def _plot_1n2(self, fig, axs, i, j, val, title):
        """Plot *val* as a 1-D slice in column ``i`` and a 2-D map in ``i + 4``.

        Args:
            fig: Figure owning ``axs`` (needed for the colorbar).
            axs: 2-D array of axes, indexed ``axs[row, column]``.
            i: Column of the 1-D plot; the 2-D map goes to column ``i + 4``.
            j: Row of both plots.
            val: Tensor with ``n_pts * n_pts`` elements in grid order.
            title: Title used for both axes.
        """
        val = val.reshape(self.n_pts, self.n_pts)
        self._plot_1d(axs, i, j, val, title)
        self._plot_2d(fig, axs, i+4, j, val, title)

    def _plot_1d(self, axs, i, j, val, title):
        """Plot the slice ``val[:, mid]`` against the grid coordinates."""
        ax = axs[j, i]
        ax.plot(self.linsp, val[:, self.mid])
        ax.set_title(title)
        ax.set_xlim(-self.range_lim, self.range_lim)

    def _plot_2d(self, fig, axs, i, j, val, title):
        """Draw *val* over the grid as a colour map with its own colorbar."""
        xx, yy, _ = self.test_grid
        ax = axs[j, i]
        im = ax.pcolormesh(xx, yy, val, cmap=plt.cm.jet, shading='auto')
        ax.set_facecolor(plt.cm.jet(0.))
        ax.set_title(title)
        fig.colorbar(im, ax=ax)
        self._format_ax(ax)

    def _plot_scatter_2d(self, axs, i, j, sample, title):
        """Scatter the first two columns of *sample* on ``axs[j, i]``."""
        ax = axs[j, i]
        # No `cmap`: without per-point colours (`c`) matplotlib ignores it
        # and only emits a warning.
        ax.scatter(sample[:, 0], sample[:, 1], marker='.', alpha=0.03)
        ax.set_title(title)
        self._format_ax(ax)

    def _plot_hist_2d(self, axs, i, j, data, title):
        """Draw a 2-D histogram of *data* on ``axs[j, i]`` over the grid range."""
        ax = axs[j, i]
        lim = [-self.range_lim, self.range_lim]
        ax.hist2d(data[:, 0], data[:, 1],
                  range=[lim, lim],
                  bins=self.n_pts, cmap=plt.cm.jet)
        ax.set_title(title)
        self._format_ax(ax)

    def _format_ax(self, ax, two_d=True):
        """Clamp *ax* to the grid range, hide both axes and flip the y axis.

        Args:
            ax: Matplotlib axes to format.
            two_d: When true, the y limits are clamped as well.
        """
        ax.set_xlim(-self.range_lim, self.range_lim)
        if two_d:
            ax.set_ylim(-self.range_lim, self.range_lim)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.invert_yaxis()
